package websockethttp

import (
	"container/list"
	"encoding/json"
	"errors"
	"github.com/gorilla/websocket"
	"log"
	"net/http"
	"strconv"
	"time"
)

// Server is a websocket server that exchanges JSON-encoded requests/responses with its clients
// in both directions: clients call registered Process handlers, and the server can send requests
// to a client and receive its responses through callbacks.
//
// NOTE(review): none of the maps/lists below is guarded by a mutex, yet this file accesses them
// from HTTP handler goroutines, per-message goroutines (readMessage) and ticker goroutines
// (ServerNew). Concurrent map writes can crash the process — consider adding a sync.Mutex.
type Server struct {
	Name                   string                                     // service name
	requestTimeout         int64                                      // timeout (milliseconds) for server-to-client requests
	isShowPongLogs         bool                                       // whether heartbeat (pong) requests are logged
	isClusterMode          bool                                       // cluster-mode flag; NOTE(review): not read anywhere in this file
	authorizationFunc      func(request *http.Request) (string, bool) // authorizes a connection; returns (connection name, ok)
	writeBeforeFunc        func(data []byte) []byte                   // applied to payloads before they are written (e.g. compression)
	readBeforeFunc         func(data []byte) []byte                   // applied to payloads right after they are read (e.g. decompression)
	defaultProcessFunc     func(context *ServerContext)               // fallback handler when no registered Process matches
	connOpenEventFuncList  *list.List                                 // listeners for newly opened connections
	connCloseEventFuncList *list.List                                 // listeners for closed connections
	serverRequestFilters   *list.List                                 // filters for server -> client requests
	serverResponseFilters  *list.List                                 // filters for client -> server responses (to server requests)
	clientRequestFilters   *list.List                                 // filters for client -> server requests
	clientResponseFilters  *list.List                                 // filters for server -> client responses (to client requests)
	requestCommonHeader    map[string]string                          // headers merged into every server-initiated request
	responseCommonHeader   map[string]string                          // headers copied into every response to a client request
	clientChannelMap       map[string]*ClientConnChannel              // registered connections, keyed by connection name
	requestCallbackMap     map[string]*ServerRequestCallback          // pending request callbacks, keyed by request Uid
	processFuncMap         map[string]func(context *ServerContext)    // Process handlers, keyed by process name
	connTimeoutTicker      *time.Ticker                               // ticker driving the connection (heartbeat) timeout check
	requestTimeoutTicker   *time.Ticker                               // ticker driving the request timeout check
	upgrader               *websocket.Upgrader                        // upgrader used to turn HTTP requests into websockets
}

// ServerContext bundles everything about one request/response exchange: the connection it
// belongs to, the request, the response being built or received, and any send error.
type ServerContext struct {
	Channel  *ClientConnChannel `json:"channel"`
	Request  *SocketRequest     `json:"request"`
	Response *SocketResponse    `json:"response"`
	Error    error              `json:"error"`
}

// ServerRequestCallback records a pending asynchronous callback for a server-initiated request.
//
// Context is the request context handed to Callback; CreateTime (Unix milliseconds) is used by
// the request-timeout check to expire the entry.
type ServerRequestCallback struct {
	Context *ServerContext `json:"context"`
	// Func values cannot be JSON-encoded (json.Marshal returns UnsupportedTypeError), so the
	// callback is excluded from serialization.
	Callback   func(ctx *ServerContext) `json:"-"`
	CreateTime int64                    `json:"create_time"`
}

// ClientConnChannel wraps one client websocket connection.
//
// NOTE(review): PongTime is written by heartbeat handler goroutines and read by the connection
// timeout ticker without synchronization — confirm whether this race is acceptable.
type ClientConnChannel struct {
	ConnName   string          // client (connection) name, as returned by the authorization function
	PongTime   int64           // last heartbeat time, Unix milliseconds
	ConnSocket *websocket.Conn // underlying websocket connection
	BindServer *Server         // server this connection belongs to
}

// UpgradeHttp authorizes an HTTP request and upgrades it to a websocket connection. It has the
// http.HandlerFunc signature, so it can also be mounted in third-party web frameworks (e.g. gin).
//
// authorizationFunc must return a non-empty connection name and true; otherwise the request is
// answered with 401 and the body "authorization error". The new ClientConnChannel is registered in
// clientChannelMap under that name. If a connection with the same name already exists, that older
// connection is closed with code 1000 and reason "repetition_conn". The connection-open listeners
// are then invoked and a goroutine is started to read messages from the client.
func (server *Server) UpgradeHttp(writer http.ResponseWriter, request *http.Request) {
	// A missing authorization function is a server misconfiguration: reject the request instead
	// of panicking on a nil function call.
	if server.authorizationFunc == nil {
		log.Printf("UpgradeHttp: authorizationFunc is not set, connection rejected")
		writer.WriteHeader(http.StatusInternalServerError)
		_, _ = writer.Write([]byte("authorization error")) // best effort: the request is rejected anyway
		return
	}
	name, authorized := server.authorizationFunc(request)
	if !authorized || name == "" {
		writer.WriteHeader(http.StatusUnauthorized)
		_, _ = writer.Write([]byte("authorization error")) // best effort: the request is rejected anyway
		return
	}

	conn, err := server.upgrader.Upgrade(writer, request, nil)
	if err != nil {
		// On failure Upgrade has already replied to the client (through upgrader.Error) or has
		// hijacked the connection, so writing to writer here would be superfluous or invalid.
		log.Printf("UpgradeHttp: upgrader error name(%v) err(%v)", name, err)
		return
	}

	channel := &ClientConnChannel{
		ConnName:   name, // the connection name can carry e.g. a user ID
		PongTime:   time.Now().UnixMilli(),
		ConnSocket: conn,
		BindServer: server,
	}

	// One connection per name: kick the previously registered connection, not the new one.
	if previous, exists := server.clientChannelMap[name]; exists && previous != nil {
		previous.CloseConnection(1000, "repetition_conn")
	}
	server.clientChannelMap[name] = channel

	// Notify connection-open listeners.
	for element := server.connOpenEventFuncList.Front(); element != nil; element = element.Next() {
		element.Value.(func(channel *ClientConnChannel))(channel)
	}

	// A close frame from the client tears the channel down.
	conn.SetCloseHandler(func(code int, text string) error {
		channel.CloseConnection(code, text)
		return nil
	})

	go server.readMessage(channel) // read loop runs in its own goroutine
}

// LauncherServer mounts UpgradeHttp at path on http.DefaultServeMux and serves HTTP on the given
// port. It blocks until the listener fails and returns that error.
//
// NOTE(review): http.ListenAndServe uses no read/write timeouts, and registering the same path
// twice on DefaultServeMux panics.
func (server *Server) LauncherServer(path string, port int) error {
	http.HandleFunc(path, server.UpgradeHttp)
	addr := ":" + strconv.Itoa(port)
	return http.ListenAndServe(addr, nil)
}

// generateDefaultUpgrader builds the websocket.Upgrader used until SetWebsocketUpgrader replaces it:
// 1 KiB read/write buffers, no compression, no subprotocols, and any Origin accepted.
func (server *Server) generateDefaultUpgrader() *websocket.Upgrader {
	return &websocket.Upgrader{
		// time.Duration counts nanoseconds: a bare 3600 was 3.6µs. The rest of this package works
		// in milliseconds, so 3600ms is used for the handshake deadline.
		HandshakeTimeout:  3600 * time.Millisecond,
		ReadBufferSize:    1024,
		WriteBufferSize:   1024,
		EnableCompression: false,
		// NOTE(review): accepting every Origin disables cross-origin protection for browser clients.
		CheckOrigin: func(r *http.Request) bool { return true },
		// Handshake errors are triggered by client input, so log them and reply with an HTTP error
		// instead of panicking.
		Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
			log.Printf("generateDefaultUpgrader error: status(%v) error(%v)", status, reason)
			http.Error(w, http.StatusText(status), status)
		},
	}
}

// readMessage is the per-connection read loop. It blocks on ReadMessage and dispatches each
// frame to its own goroutine. Text frames carry requests from the client (handlerRequest), and
// binary frames carry responses to server-initiated requests (handlerResponse). Frames of any
// other type are ignored. The first read error closes the channel and ends the loop.
//
// NOTE(review): handlers run concurrently, and gorilla/websocket allows only one concurrent
// writer per connection. Confirm that writeResponseMessage/writeRequestMessage serialize writes.
func (server *Server) readMessage(channel *ClientConnChannel) {
	for {
		messageType, payload, readErr := channel.ConnSocket.ReadMessage()
		if readErr != nil {
			channel.CloseConnection(1000, "error")
			return
		}
		if messageType == websocket.TextMessage {
			go server.handlerRequest(channel, payload)
		} else if messageType == websocket.BinaryMessage {
			go server.handlerResponse(channel, payload)
		}
	}
}

// handlerRequest processes one client-initiated request (a text frame read by readMessage).
//
// The steps are, in order:
//   - readBeforeFunc is applied to the payload and the result is JSON-decoded into a SocketRequest.
//   - A response is built that carries the request Uid and the common response headers.
//   - The client request filters run.
//   - The registered Process runs, or defaultProcessFunc if none matches.
//   - The client response filters run.
//   - The response is JSON-encoded, writeBeforeFunc is applied, and the result is written back.
//
// A filter returning true stops processing, and no response is sent. Decode, encode and write
// failures are logged.
func (server *Server) handlerRequest(channel *ClientConnChannel, data []byte) {
	// Hook for decompression and similar transformations.
	msg := server.readBeforeFunc(data)

	request := new(SocketRequest)
	if err := json.Unmarshal(msg, request); err != nil {
		log.Printf("客户端发送的 request 数据格式错误 %v %v", err, string(msg))
		return
	}

	response := new(SocketResponse)
	response.Uid = request.Uid
	response.Header = make(map[string]string, len(server.responseCommonHeader))
	for k, v := range server.responseCommonHeader {
		response.Header[k] = v
	}

	context := &ServerContext{Channel: channel, Request: request, Response: response}

	for element := server.clientRequestFilters.Front(); element != nil; element = element.Next() {
		if element.Value.(func(ctx *ServerContext) bool)(context) {
			return
		}
	}

	if processFunc, ok := server.processFuncMap[request.Process]; ok && processFunc != nil {
		processFunc(context)
	} else if server.defaultProcessFunc != nil { // guard against SetDefaultProcessFunc(nil)
		server.defaultProcessFunc(context)
	}

	for element := server.clientResponseFilters.Front(); element != nil; element = element.Next() {
		if element.Value.(func(ctx *ServerContext) bool)(context) {
			return
		}
	}

	jsonByte, err := json.Marshal(response)
	if err != nil {
		// Do not send an empty/partial payload the client could not decode.
		log.Printf("handlerRequest: response转为json异常 %v", err)
		return
	}

	// Hook for compression and similar transformations.
	payload := server.writeBeforeFunc(jsonByte)

	if err := writeResponseMessage(channel.ConnSocket, payload); err != nil {
		log.Printf("handlerRequest: write response error %v", err)
	}
}

// handlerResponse handles a client's response (a binary frame read by readMessage) to a
// server-initiated request.
//
// readBeforeFunc is applied to the payload, which is then JSON-decoded into a SocketResponse. The
// pending callback registered by sendRequest under the same Uid is removed from
// requestCallbackMap, and its context's Response is filled from the decoded response. The
// callback is invoked, and the server response filters run after it; a filter returning true
// stops the remaining filters. Unknown Uids and decode errors are logged and ignored.
func (server *Server) handlerResponse(channel *ClientConnChannel, data []byte) {
	// Hook for decompression and similar transformations.
	msg := server.readBeforeFunc(data)

	response := new(SocketResponse)
	if err := json.Unmarshal(msg, response); err != nil {
		log.Printf("handlerResponse: json转为response错误 %v", err)
		return
	}

	rc, ok := server.requestCallbackMap[response.Uid]
	if !ok || rc == nil {
		log.Printf("handlerResponse: 找不到对应的 callback(%v)", string(msg))
		return
	}

	// Unregister before running any user code. Deleting only after the callback and filters
	// would leave a window in which the request-timeout checker still sees the entry and invokes
	// the same callback a second time (e.g. after SendMessageSync has already returned).
	delete(server.requestCallbackMap, response.Uid)

	// Copy the received response into the context the caller is holding.
	rc.Context.Response.Header = response.Header
	rc.Context.Response.Code = response.Code
	rc.Context.Response.Message = response.Message
	rc.Context.Response.Body = response.Body

	rc.Context.Channel = channel // normally already set by sendRequest

	rc.Callback(rc.Context)

	for element := server.serverResponseFilters.Front(); element != nil; element = element.Next() {
		if element.Value.(func(ctx *ServerContext) bool)(rc.Context) {
			return
		}
	}
}

// sendRequest sends a server-initiated request to channel. It is a low-level helper; prefer
// SendMessage / SendMessageSync.
//
// The steps are, in order:
//   - requestCommonHeader is merged with request.Header; request values win.
//   - callback, if non-nil, is registered in requestCallbackMap under request.Uid.
//   - The server request filters run.
//   - The request is JSON-encoded, writeBeforeFunc is applied, and the result is written.
//
// A filter returning true aborts the send. A callback that is already registered is then only
// invoked when the request times out. Encoding and write errors are stored in the context's
// Error field.
func (server *Server) sendRequest(channel *ClientConnChannel, request *SocketRequest, callback func(c *ServerContext)) {
	mergeHeader := make(map[string]string, len(server.requestCommonHeader)+len(request.Header))
	for k, v := range server.requestCommonHeader {
		mergeHeader[k] = v
	}
	for k, v := range request.Header {
		mergeHeader[k] = v
	}
	request.Header = mergeHeader

	response := new(SocketResponse)
	response.Uid = request.Uid

	context := &ServerContext{Channel: channel, Request: request, Response: response}

	// Register the callback before sending so a fast response always finds it.
	if callback != nil {
		server.requestCallbackMap[request.Uid] = &ServerRequestCallback{
			Context:    context,
			Callback:   callback,
			CreateTime: time.Now().UnixMilli(),
		}
	}

	for element := server.serverRequestFilters.Front(); element != nil; element = element.Next() {
		if element.Value.(func(c *ServerContext) bool)(context) {
			return
		}
	}

	jsonByte, err := json.Marshal(request)
	if err != nil {
		// Do not send an empty payload; surface the failure through the context instead.
		log.Printf("sendRequest: request转为json异常 %v", err)
		context.Error = err
		return
	}

	// Hook for compression and similar transformations.
	payload := server.writeBeforeFunc(jsonByte)

	context.Error = writeRequestMessage(channel.ConnSocket, payload)
}

// enableHealthCheckProcess registers the heartbeat Process under PongProcessName, a name agreed
// upon with the clients. Each heartbeat refreshes the channel's PongTime (used by the connection
// timeout check) and answers with PongProcessSuccess / "ok". The heartbeat is logged when
// isShowPongLogs is set.
func (server *Server) enableHealthCheckProcess() {
	pong := func(context *ServerContext) {
		ch := context.Channel
		if server.isShowPongLogs {
			log.Printf("Health: name(%s) body(%s)", ch.ConnName, context.Request.Body)
		}
		ch.PongTime = time.Now().UnixMilli()
		context.Response.Code = PongProcessSuccess
		context.Response.Message = "ok"
	}
	server.RegisterProcessFunc(PongProcessName, pong)
}

// openConnTimeoutCheck closes stale connections. ServerNew runs it in its own goroutine, and it
// never returns. Every PongProcessTime milliseconds it closes, asynchronously and with
// RequestTimeoutCode/RequestTimeoutMsg, each registered channel whose last heartbeat is older
// than three heartbeat periods. Any previously started ticker is stopped first.
func (server *Server) openConnTimeoutCheck() {
	if server.connTimeoutTicker != nil {
		server.connTimeoutTicker.Stop()
	}
	server.connTimeoutTicker = time.NewTicker(time.Millisecond * time.Duration(PongProcessTime))
	for range server.connTimeoutTicker.C {
		deadline := time.Now().UnixMilli() - PongProcessTime*3
		for _, channel := range server.clientChannelMap {
			if channel.PongTime < deadline {
				// The go statement evaluates the receiver now. Capturing the range variable in a
				// closure would, before Go 1.22, share it across iterations and close the wrong
				// connection.
				go channel.CloseConnection(RequestTimeoutCode, RequestTimeoutMsg)
			}
		}
	}
}

// openRequestTimeoutCheck expires pending callbacks of server-initiated requests. ServerNew runs
// it in its own goroutine, and it never returns. A non-positive requestTimeout is first replaced
// with RequestDefaultTimeout. Then, every requestTimeout milliseconds, each callback older than
// requestTimeout is removed from requestCallbackMap. It is invoked in a new goroutine with
// Response.Code/Message set to RequestTimeoutCode/RequestTimeoutMsg.
func (server *Server) openRequestTimeoutCheck() {
	if server.requestTimeout <= 0 {
		server.requestTimeout = RequestDefaultTimeout
	}
	if server.requestTimeoutTicker != nil {
		server.requestTimeoutTicker.Stop()
	}
	server.requestTimeoutTicker = time.NewTicker(time.Millisecond * time.Duration(server.requestTimeout))
	for range server.requestTimeoutTicker.C {
		deadline := time.Now().UnixMilli() - server.requestTimeout
		for key, pending := range server.requestCallbackMap {
			if pending.CreateTime >= deadline {
				continue
			}
			delete(server.requestCallbackMap, key) // deleting during range is safe in Go
			// Pass the entry explicitly. Before Go 1.22 the range variable is shared by all
			// iterations, so a closure capturing it could time out the wrong request.
			go func(expired *ServerRequestCallback) {
				expired.Context.Response.Code = RequestTimeoutCode
				expired.Context.Response.Message = RequestTimeoutMsg
				expired.Callback(expired.Context)
			}(pending)
		}
	}
}

// =====================================================================================================================

// SetAuthorizationFunc installs the function UpgradeHttp uses to authorize incoming connections.
// The function returns the connection name (the key in the connection registry) and whether the
// request is allowed.
func (server *Server) SetAuthorizationFunc(fun func(request *http.Request) (string, bool)) {
	server.authorizationFunc = fun
}

// AddConnOpenEventFunc appends a listener that is called with each newly opened connection.
func (server *Server) AddConnOpenEventFunc(fun func(channel *ClientConnChannel)) {
	listeners := server.connOpenEventFuncList
	listeners.PushBack(fun)
}

// AddConnCloseEventFunc appends a listener that is called when a connection is closed.
func (server *Server) AddConnCloseEventFunc(fun func(channel *ClientConnChannel)) {
	listeners := server.connCloseEventFuncList
	listeners.PushBack(fun)
}

// SetRequestTimeout sets how long, in milliseconds, a server-initiated request waits for its
// response before its callback is invoked with RequestTimeoutCode. Non-positive values fall back
// to RequestDefaultTimeout, as openRequestTimeoutCheck does. Without the fallback, every pending
// request would expire on the next tick.
//
// The request-timeout ticker's interval is fixed when the checker starts. A later call therefore
// changes only the expiry threshold, not how often expired requests are swept.
func (server *Server) SetRequestTimeout(timeout int64) {
	if timeout <= 0 {
		timeout = RequestDefaultTimeout
	}
	server.requestTimeout = timeout
}

// ShowPongLogs turns logging of heartbeat (pong) requests on or off.
func (server *Server) ShowPongLogs(isShowPongLogs bool) {
	server.isShowPongLogs = isShowPongLogs
}

// RegisterProcessFunc registers (or replaces) the handler for requests whose Process equals
// processFuncName. All business logic is expected to live in such Process handlers.
func (server *Server) RegisterProcessFunc(processFuncName string, processFunc func(context *ServerContext)) {
	handlers := server.processFuncMap
	handlers[processFuncName] = processFunc
}

// GetClientChannel returns the connection registered under name (the name returned by the
// authorization function), or an error if no such connection exists.
func (server *Server) GetClientChannel(name string) (*ClientConnChannel, error) {
	// A missing key yields nil, so a single nil check covers both "absent" and "stored nil".
	if found := server.clientChannelMap[name]; found != nil {
		return found, nil
	}
	return nil, errors.New("ClientConnChannel it doesn't exist")
}

// GetAllClientChannel returns all registered connections keyed by name.
//
// The returned map is the server's internal registry itself, not a copy. Callers must not
// modify it.
func (server *Server) GetAllClientChannel() map[string]*ClientConnChannel {
	registry := server.clientChannelMap
	return registry
}

// AddClientRequestFilter appends a filter run on client -> server requests before they are
// processed. A filter returning true stops processing, and no response is sent.
func (server *Server) AddClientRequestFilter(filter func(context *ServerContext) bool) {
	filters := server.clientRequestFilters
	filters.PushBack(filter)
}

// AddClientResponseFilter appends a filter run on the server's responses to client requests,
// before they are sent. A filter returning true suppresses the response.
func (server *Server) AddClientResponseFilter(filter func(context *ServerContext) bool) {
	filters := server.clientResponseFilters
	filters.PushBack(filter)
}

// AddServerRequestFilter appends a filter run on server -> client requests before they are sent.
// A filter returning true aborts the send.
func (server *Server) AddServerRequestFilter(filter func(context *ServerContext) bool) {
	filters := server.serverRequestFilters
	filters.PushBack(filter)
}

// AddServerResponseFilter appends a filter run on clients' responses to server-initiated
// requests. A filter returning true stops the remaining filters.
func (server *Server) AddServerResponseFilter(filter func(context *ServerContext) bool) {
	filters := server.serverResponseFilters
	filters.PushBack(filter)
}

// SetRequestCommonHeader sets the headers merged into every server-initiated request. Headers
// given on an individual request take precedence. The map is stored as-is, not copied.
func (server *Server) SetRequestCommonHeader(header map[string]string) {
	server.requestCommonHeader = header
}

// SetResponseCommonHeader sets the common response headers, which are copied into every response
// the server sends for a client request. The map is stored as-is, not copied.
func (server *Server) SetResponseCommonHeader(header map[string]string) {
	server.responseCommonHeader = header
}

// SetWriteBeforeFunc sets the transformation applied to every outgoing payload just before it is
// written, e.g. compression.
func (server *Server) SetWriteBeforeFunc(fun func(data []byte) []byte) {
	server.writeBeforeFunc = fun
}

// SetReadBeforeFunc sets the transformation applied to every incoming payload before it is
// decoded, e.g. decompression.
func (server *Server) SetReadBeforeFunc(fun func(data []byte) []byte) {
	server.readBeforeFunc = fun
}

// SetWebsocketUpgrader replaces the websocket.Upgrader that UpgradeHttp uses to upgrade HTTP
// requests.
func (server *Server) SetWebsocketUpgrader(upgrader *websocket.Upgrader) {
	server.upgrader = upgrader
}

// SetDefaultProcessFunc sets the fallback handler invoked for client requests whose Process has
// no registered handler.
func (server *Server) SetDefaultProcessFunc(fun func(context *ServerContext)) {
	server.defaultProcessFunc = fun
}

// =====================================================================================================================

// SendMessage sends a request to this client asynchronously.
//
// A fresh id from NewMessageId becomes the request Uid. header may be nil; it is merged with the
// server's common request headers before sending. callback may be nil. Otherwise it is invoked
// with the request context when the client's response arrives or when the request times out.
func (channel *ClientConnChannel) SendMessage(process string, header map[string]string, body string, callback func(context *ServerContext)) {
	request := new(SocketRequest)
	request.Uid = NewMessageId()
	request.Process = process
	request.Header = header
	request.Body = body
	// Default only a missing header map. The previous inverted check (!= nil) discarded every
	// header the caller supplied.
	if request.Header == nil {
		request.Header = make(map[string]string)
	}
	channel.BindServer.sendRequest(channel, request, callback)
}

// SendMessageSync sends a request to this client and blocks until its callback fires. The
// callback fires either when the client's response arrives or when the request times out. It
// returns the request context together with the context's Error.
//
// NOTE(review): the wait channel is closed on return via safeCloseServerContextChannel (defined
// elsewhere). A callback firing a second time would then send on a closed channel. Confirm that
// the helper and the callback bookkeeping make that impossible.
func (channel *ClientConnChannel) SendMessageSync(process string, header map[string]string, body string) (*ServerContext, error) {
	await := make(chan *ServerContext)
	defer safeCloseServerContextChannel(await)

	deliver := func(ctx *ServerContext) {
		await <- ctx
	}
	channel.SendMessage(process, header, body, deliver)

	result := <-await
	return result, result.Error
}

// CloseConnection closes the underlying websocket connection, logs the close, and unregisters the
// channel.
//
// The registry entry is removed only if this exact channel is still registered under its name.
// Closing a connection that was already replaced by a newer one with the same name therefore
// does not evict the replacement. Close listeners are notified only for that registered close.
// Repeated calls, such as the close handler followed by the read loop's error, do not fire them
// twice.
func (channel *ClientConnChannel) CloseConnection(code int, message string) {
	server := channel.BindServer

	current, found := server.clientChannelMap[channel.ConnName]
	registered := found && current == channel
	if registered {
		// Deferred so close listeners can still look the channel up while they run.
		defer delete(server.clientChannelMap, channel.ConnName)
	}

	err := channel.ConnSocket.Close()
	log.Printf("连接关闭: tags(%v) code(%v) msg(%v) err(%v)", channel.ConnName, code, message, err)

	if !registered {
		return
	}
	for element := server.connCloseEventFuncList.Front(); element != nil; element = element.Next() {
		element.Value.(func(channel *ClientConnChannel))(channel)
	}
}

// =====================================================================================================================

// ServerNew creates a server with empty registries, identity read/write transforms, a logging
// default Process, the default websocket upgrader and the heartbeat Process. It also starts the
// connection-timeout and request-timeout checkers in background goroutines.
//
// NOTE(review): those goroutines and their tickers are never stopped, so each Server lives for
// the lifetime of the process.
func ServerNew(serverName string) *Server {
	server := &Server{
		Name:                   serverName,
		requestTimeout:         RequestDefaultTimeout,
		isShowPongLogs:         false,
		connOpenEventFuncList:  list.New(),
		connCloseEventFuncList: list.New(),
		serverRequestFilters:   list.New(),
		serverResponseFilters:  list.New(),
		clientRequestFilters:   list.New(),
		clientResponseFilters:  list.New(),
		requestCommonHeader:    make(map[string]string),
		clientChannelMap:       make(map[string]*ClientConnChannel),
		requestCallbackMap:     make(map[string]*ServerRequestCallback),
		processFuncMap:         make(map[string]func(context *ServerContext)),
		defaultProcessFunc: func(context *ServerContext) {
			log.Printf("找不到对应的 Process %v", context.Request.Process)
		},
		writeBeforeFunc: func(data []byte) []byte { return data },
		readBeforeFunc:  func(data []byte) []byte { return data },
	}
	server.upgrader = server.generateDefaultUpgrader()

	server.enableHealthCheckProcess()

	go server.openConnTimeoutCheck()
	go server.openRequestTimeoutCheck()

	return server
}
